import numpy as np
import sophuspy as sp
from util.imu_utils import rightJ, invRightJ

class IMUIntegrator:
    """Preintegrate IMU samples over the time window ``[prev, curr]``.

    The integrator accumulates:

    * the relative rotation ``dR`` (an ``sp.SO3``), velocity ``dV`` and
      position ``dP`` increments and the total duration ``dT``;
    * first-order Jacobians of those increments w.r.t. the gyroscope bias
      (``JRg``, ``JVg``, ``JPg``) and the accelerometer bias (``JVa``, ``JPa``);
    * a 15x15 covariance ``cov``. Indices 0:3, 3:6 and 6:9 are the rotation,
      velocity and position blocks. 9:12 and 12:15 accumulate the gyroscope
      and accelerometer bias random-walk noise.

    Each measurement row is laid out as ``[t, gx, gy, gz, ax, ay, az]``:
    index 0 is the timestamp, 1:4 the gyroscope and 4: the accelerometer
    reading.
    """

    def __init__(self, prev, curr, init_bg, init_ba, init_g, args, config):
        """Create an empty preintegration over ``[prev, curr]``.

        Args:
            prev: Start timestamp of the window.
            curr: End timestamp of the window.
            init_bg: Gyroscope bias (3-vector) used while integrating.
            init_ba: Accelerometer bias (3-vector) used while integrating.
            init_g: Gravity vector; rotated by ``Rwg`` in ``compute_error``.
            args: Not used by this class.
            config: Mapping whose ``config['IMU']`` entry provides
                ``frequency``, ``gyroscope_noise_density``,
                ``accelerometer_noise_density``, ``gyroscope_random_walk`` and
                ``accelerometer_random_walk``.
        """
        self.prev = prev
        self.curr = curr
        self.dT = 0.
        self.g = init_g

        # Preintegrated increments.
        self.dP = np.zeros((3,))
        self.dR = sp.SO3()
        self.dV = np.zeros((3,))

        # bg/ba: biases the samples are integrated with (linearization point).
        # bg_new/ba_new: later estimates, applied through the Jacobians below.
        self.bg = init_bg
        self.ba = init_ba
        self.bg_new = init_bg
        self.ba_new = init_ba

        # Jacobians of the increments w.r.t. gyro (g) / accel (a) bias.
        self.JPg = np.zeros((3, 3))
        self.JPa = np.zeros((3, 3))
        self.JRg = np.zeros((3, 3))
        self.JVg = np.zeros((3, 3))
        self.JVa = np.zeros((3, 3))

        # Discrete noise covariances, diagonal blocks ordered [gyro, accel].
        # White-noise densities are scaled by sqrt(f); random-walk densities
        # are divided by sqrt(f).
        FREQ = config['IMU']['frequency']
        self.Nga = np.eye(6)
        self.Nga[:3, :3] *= (config['IMU']['gyroscope_noise_density'] * np.sqrt(FREQ))**2
        self.Nga[3:, 3:] *= (config['IMU']['accelerometer_noise_density'] * np.sqrt(FREQ))**2
        self.NgaWalk = np.eye(6)
        self.NgaWalk[:3, :3] *= (config['IMU']['gyroscope_random_walk'] / np.sqrt(FREQ))**2
        self.NgaWalk[3:, 3:] *= (config['IMU']['accelerometer_random_walk'] / np.sqrt(FREQ))**2
        self.cov = np.zeros((15, 15))

    def set_new_bias(self, bias):
        """Store an updated bias estimate ``bias = [bg (3), ba (3)]``.

        The stored values are slices of ``bias``. If ``bias`` is an ndarray
        they are views, so later in-place edits to ``bias`` show up here.
        The ``get_updated_*`` corrections are first order in the change from
        ``bg``/``ba``, so their accuracy degrades as that change grows.
        """
        # assert np.linalg.norm(bias[:3] - self.bg) < 0.1
        # assert np.linalg.norm(bias[3:] - self.ba) < 1
        self.bg_new = bias[:3]
        self.ba_new = bias[3:]

    def get_updated_dP(self):
        """Return ``dP`` corrected to first order for the new biases."""
        dbg = self.bg_new - self.bg
        dba = self.ba_new - self.ba
        return self.dP + self.JPg @ dbg + self.JPa @ dba

    def get_updated_dR(self):
        """Return ``dR`` right-multiplied by ``Exp(JRg @ (bg_new - bg))``."""
        dbg = self.bg_new - self.bg
        return self.dR * sp.SO3.exp(self.JRg @ dbg)

    def get_updated_dV(self):
        """Return ``dV`` corrected to first order for the new biases."""
        dbg = self.bg_new - self.bg
        dba = self.ba_new - self.ba
        return self.dV + self.JVg @ dbg + self.JVa @ dba

    def compute_error(self, G0, G1, V0, V1, Rwg=None, raw=False, scale=1):
        """Return the preintegration residual between two states.

        Args:
            G0, G1: Pose matrices of the two states. Only the rotation
                ``[:3, :3]`` and translation ``[:3, 3]`` parts are read.
            V0, V1: Velocities of the two states (3-vectors).
            Rwg: Rotation applied to ``self.g``. ``None`` means identity.
            raw: If true, return ``(eR, eV, eP)`` with ``eR`` an ``sp.SO3``.
            scale: Factor multiplying the position and velocity differences.

        Returns:
            The 9-vector ``[log(eR), eV, eP]``, or the raw tuple if ``raw``.
        """
        if Rwg is None:
            Rwg = np.eye(3)
        dP = self.get_updated_dP()
        dR = self.get_updated_dR()
        dV = self.get_updated_dV()
        g = Rwg @ self.g
        R0T = G0[:3, :3].transpose()
        eP = R0T @ (scale*(G1[:3, 3] - G0[:3, 3] - V0*self.dT) - 0.5*g*self.dT*self.dT) - dP
        eV = R0T @ (scale*(V1 - V0) - g*self.dT) - dV
        eR = dR.inverse() * sp.SO3(sp.to_orthogonal(R0T @ G1[:3, :3]))
        if raw:
            return eR, eV, eP
        return np.concatenate([eR.log(), eV, eP])

    def jacobian(self, G0, G1, V0, V1, Rwg=None, scale=1):
        """Return Jacobians of the ``compute_error`` residual.

        Row order matches the residual: 0:3 rotation, 3:6 velocity, 6:9
        position. In the 9x6 pose Jacobians, columns 0:3 receive the
        translation terms and 3:6 the rotation terms. NOTE(review): confirm
        this matches the caller's pose parameterization.

        Returns:
            ``(Ji, Jj, Jvi, Jvj, Jbg, Jba, Jscale)`` with shapes 9x6, 9x6,
            9x3, 9x3, 9x3, 9x3 and 9x1. The last five are returned negated
            (see the comment at the return statement).
        """
        if Rwg is None:
            Rwg = np.eye(3)
        g = Rwg @ self.g
        eR, eV, eP = self.compute_error(G0, G1, V0, V1, Rwg, True, scale)
        invJr = invRightJ(eR.log())

        Ji = np.zeros((9, 6))
        Ji[6:9,:3] = -np.eye(3) * scale
        Ji[:3,3:6] = -invJr @ G1[:3,:3].transpose() @ G0[:3,:3]
        Ji[3:6,3:6] = sp.SO3.hat(G0[:3, :3].transpose() @ (scale*(V1 - V0) - g*self.dT))
        Ji[6:9,3:6] = sp.SO3.hat(G0[:3, :3].transpose() @ (scale*(G1[:3, 3] - G0[:3, 3] - V0*self.dT) - 0.5*g*self.dT*self.dT))

        Jj = np.zeros((9, 6))
        Jj[6:9,:3] = scale*G0[:3,:3].transpose() @ G1[:3,:3]
        Jj[:3,3:6] = invJr

        Jvi = np.zeros((9, 3))
        Jvi[3:6,:] = -scale*G0[:3,:3].transpose()
        Jvi[6:9,:] = -scale*G0[:3,:3].transpose() * self.dT
        Jvj = np.zeros((9, 3))
        Jvj[3:6,:] = scale*G0[:3,:3].transpose()

        Jbg = np.zeros((9, 3))
        dbg = self.bg_new - self.bg
        Jbg[:3,:] = -invJr @ eR.matrix().transpose() @ rightJ(self.JRg @ dbg) @ self.JRg
        Jbg[3:6,:] = -self.JVg
        Jbg[6:9,:] = -self.JPg

        Jba = np.zeros((9, 3))
        Jba[3:6,:] = -self.JVa
        Jba[6:9,:] = -self.JPa

        Jscale = np.zeros((9, 1))
        Jscale[3:6,0] = G0[:3, :3].transpose() @ (V1 - V0)
        Jscale[6:9,0] = G0[:3, :3].transpose() @ (G1[:3, 3] - G0[:3, 3] - V0*self.dT)
        # somehow the minus sign is needed..
        # NOTE(review): the sign flip is applied to the velocity/bias/scale
        # Jacobians but not to Ji/Jj; presumably it compensates for the
        # caller's update convention. Confirm before changing.
        return Ji, Jj, -Jvi, -Jvj, -Jbg, -Jba, -Jscale

    def jacobian_GDir(self, G0, G1, V0, V1, Rwg=None):
        """Return the 9x3 Jacobian of the residual w.r.t. the gravity direction.

        The direction is perturbed as ``Rwg * Exp([a, b, 0])``, so only the
        first two columns are non-zero. Since
        ``d(Rwg Exp(theta) g)/dtheta = -Rwg [g]_x``, those columns come from
        ``-[g]_x``, built from ``self.g``. This keeps the magnitude consistent
        with ``compute_error``; for ``g = (0, 0, -9.81)`` the values equal the
        former hard-coded ones. The result is negated, as in ``jacobian``.
        """
        if Rwg is None:
            Rwg = np.eye(3)
        gx, gy, gz = np.asarray(self.g, dtype=float).reshape(3)
        # First two columns of -hat(g).
        Gm = np.array([[0.0, gz],
                       [-gz, 0.0],
                       [gy, -gx]])
        dGdTheta = Rwg @ Gm
        Jgdir = np.zeros((9, 3))
        Jgdir[3:6,:2] = -G0[:3, :3].transpose() @ dGdTheta*self.dT
        Jgdir[6:9,:2] = -0.5*G0[:3, :3].transpose() @ dGdTheta*self.dT*self.dT
        return -Jgdir

    def integrate(self, meas):
        """Integrate a time-ordered sequence of IMU samples over ``[prev, curr]``.

        The first interval is extended back to ``prev`` and the last forward to
        ``curr``. With exactly two samples, the single interval gets both
        corrections. Afterwards ``self.info`` holds the symmetric PSD
        information of the 9x9 rotation/velocity/position block, and
        ``self.info2`` the information of the bias random-walk block.

        Raises:
            numpy.linalg.LinAlgError: if a covariance block is singular, e.g.
                when fewer than two samples are given and nothing is integrated.
        """
        n_intervals = len(meas) - 1
        for i, (m0, m1) in enumerate(zip(meas[:-1], meas[1:])):
            self.integrate_once(m0, m1, first=(i == 0), last=(i == n_intervals - 1))
        info = np.linalg.inv(self.cov[:9, :9])
        info = (info + info.transpose()) / 2
        # eigh (not eig): the input is symmetric, and eigh returns real
        # eigenvalues with orthonormal eigenvectors, so V diag(w) V^T rebuilds
        # the matrix even when eigenvalues repeat. Tiny/negative eigenvalues
        # are zeroed to keep the result positive semi-definite.
        w, v = np.linalg.eigh(info)
        w = np.where(w < 1e-12, 0.0, w)
        self.info = v @ np.diag(w) @ v.transpose()
        self.info2 = np.linalg.inv(self.cov[9:, 9:])

    def _average(self, m0, m1, first, last):
        """Return ``(dt, acc, gyr)`` for the interval between samples m0 and m1.

        Readings are averaged over the interval. If ``first``, the interval
        starts at ``self.prev``; if ``last``, it ends at ``self.curr``. In both
        cases the start/end reading is linearly interpolated or extrapolated
        from m0 and m1. Both flags may be set at once. When the two samples
        share a timestamp, no slope is applied (avoids division by zero).
        """
        t0, t1 = m0[0], m1[0]
        tab = t1 - t0
        gyr0, gyr1 = m0[1:4], m1[1:4]
        acc0, acc1 = m0[4:], m1[4:]

        dt = tab
        acc = acc0 + acc1
        gyr = gyr0 + gyr1
        if first:
            tini = t0 - self.prev
            ratio = tini / tab if tab != 0 else 0.0
            dt = tini + dt
            acc = acc - (acc1 - acc0) * ratio
            gyr = gyr - (gyr1 - gyr0) * ratio
        if last:
            tend = self.curr - t1
            ratio = tend / tab if tab != 0 else 0.0
            dt = dt + tend
            acc = acc + (acc1 - acc0) * ratio
            gyr = gyr + (gyr1 - gyr0) * ratio
        return dt, acc * 0.5, gyr * 0.5

    def integrate_once(self, m0, m1, first=False, last=False):
        """Advance the preintegration by one sample interval.

        The averaged readings are corrected with ``self.bg``/``self.ba`` (the
        integration biases, not ``bg_new``/``ba_new``). ``dP``/``dV``, the
        bias Jacobians and the covariance all use the rotation from the start
        of the interval. ``dR`` is updated afterwards.
        """
        dt, acc, gyr = self._average(m0, m1, first, last)

        self.dT += dt
        acc -= self.ba
        gyr -= self.bg

        acc_w = self.dR.matrix() @ acc
        prev_dR = self.dR.matrix()
        deltaR = sp.SO3.exp(dt * gyr)
        # Position/velocity use dR before this step's rotation is applied.
        self.dP += self.dV*dt + 0.5*acc_w*dt*dt
        self.dV += acc_w*dt
        self.dR *= deltaR

        # Bias Jacobians. JPa/JPg read JVa/JVg from before this step, and
        # JPg/JVg read the old JRg, hence the update order.
        Wacc = sp.SO3.hat(acc)
        self.JPa += self.JVa*dt - 0.5*prev_dR*dt*dt
        self.JPg += self.JVg*dt - 0.5*prev_dR*dt*dt @ Wacc @ self.JRg
        self.JVa += -prev_dR*dt
        self.JVg += -prev_dR*dt @ Wacc @ self.JRg
        self.JRg = deltaR.matrix().transpose() @ self.JRg - rightJ(dt * gyr)*dt

        # Covariance propagation cov9 <- A cov9 A^T + B Nga B^T.
        # State order: [rotation, velocity, position]. Noise order: [gyro, accel].
        A = np.eye(9)
        A[3:6, :3] = -prev_dR*dt @ Wacc
        A[6:9, :3] = -0.5*prev_dR*dt*dt @ Wacc
        A[6:9, 3:6] = np.eye(3)*dt
        A[:3, :3] = deltaR.matrix().transpose()
        B = np.zeros((9, 6))
        B[3:6, 3:6] = prev_dR*dt
        B[6:9, 3:6] = 0.5*prev_dR*dt*dt
        B[:3, :3] = rightJ(dt * gyr)*dt
        self.cov[:9, :9] = A @ self.cov[:9, :9] @ A.transpose() + B @ self.Nga @ B.transpose()
        # Bias random-walk noise accumulates once per sample interval.
        self.cov[9:, 9:] += self.NgaWalk
